local super = require "Regression"

-- Class table for the sigmoid regression, created by calling new() on the
-- module returned by require "Regression". It is assigned without `local`,
-- so it is also reachable as a global named SigmoidRegression.
SigmoidRegression = super:new()

local _exp = math.exp

--- Unit-height logistic curve.
-- Evaluates 1 / (1 + e^(-r * (x - x0))).
-- @param x  point at which the curve is evaluated
-- @param x0 midpoint of the curve (the value there is 0.5)
-- @param r  steepness factor applied to (x - x0)
-- @return the curve value at x
local function _h(x, x0, r)
    local exponent = -r * (x - x0)
    return 1 / (1 + _exp(exponent))
end

--- Initializes the instance.
-- Runs the base initializer first, then clears the three fitted
-- coefficients so that no stale fit is left on the object.
function SigmoidRegression:init()
    super.init(self)
    -- Cleared in the order _coeff0, _coeff1, _coeff2.
    for _, field in ipairs({ '_coeff0', '_coeff1', '_coeff2' }) do
        self[field] = nil
    end
end

--- Fits y = k / (1 + e^(-r * (x - x0))) to the accumulated samples.
-- Reads self._n, self._xs, self._ys, self._xmin, self._xmax and self._ymax.
-- Writes self._coeff0 (x0), self._coeff1 (r), self._coeff2 (k) and self._r2.
-- @return true
function SigmoidRegression:finish()
    local count = self._n

    -- Affine rescaling of x that sends _xmin to 1 and _xmax to 2. The fit is
    -- done in the rescaled coordinate and converted back at the very end.
    local span = self._xmax - self._xmin
    local xScale = 1 / span
    local xOffset = 1 - self._xmin / span
    local function rescale(x)
        return x * xScale + xOffset
    end

    -- Both arrays are ordered using a sequence built from self._xs as the sort
    -- argument, so xs[i] and ys[i] keep referring to the same sample.
    local xSource = Sequence:newWithArray(self._xs)
    local xOrder = Sequence:newWithArray(self._xs)
    local xs = xSource:sort(xOrder):map(rescale):toArray()
    local ySource = Sequence:newWithArray(self._ys)
    local yOrder = Sequence:newWithArray(self._xs)
    local ys = ySource:sort(yOrder):toArray()

    -- Total sum of squares of y about its mean (denominator of r^2).
    local total = 0
    for i = 1, count do
        total = total + ys[i]
    end
    local mean = total / count
    local sstot = 0
    for i = 1, count do
        sstot = sstot + (ys[i] - mean) ^ 2
    end

    -- Initial guess for the midpoint: the sample whose symmetric window of
    -- half-width `reach` has the largest positive secant slope.
    local steepest = 0
    local steepestAt = 1
    local reach = math.ceil(math.sqrt(count) / 2)
    for center = 1 + reach, count - reach do
        local lo, hi = center - reach, center + reach
        if xs[hi] > xs[lo] then
            local slope = (ys[hi] - ys[lo]) / (xs[hi] - xs[lo])
            if slope > steepest then
                steepest = slope
                steepestAt = center
            end
        end
    end

    -- A sigmoid of height k has slope k * r / 4 at its midpoint, so the
    -- starting rate is 4 * slope / k with k taken from self._ymax.
    local r = 4 * steepest / self._ymax
    local x0 = xs[steepestAt]

    -- Returns the dot products h.h, h.y and y.y, where h is the unit-height
    -- sigmoid evaluated at every rescaled x.
    local function dots(rate, mid)
        local hh, hy, yy = 0, 0, 0
        for i = 1, count do
            local h = _h(xs[i], mid, rate)
            local y = ys[i]
            hh = hh + h * h
            hy = hy + h * y
            yy = yy + y * y
        end
        return hh, hy, yy
    end

    -- Residual sum of squares once the height is set to its least-squares
    -- value k = h.y / h.h, which leaves y.y - (h.y)^2 / h.h.
    local function getErr(rate, mid)
        local hh, hy, yy = dots(rate, mid)
        return yy - hy ^ 2 / hh
    end

    -- Starting simplex over (r, x0): the initial guess plus one vertex per
    -- parameter, each scaled by 1.125.
    local vertices = {
        { r, x0 },
        { r * 1.125, x0 },
        { r, x0 * 1.125 },
    }
    local function objective(candidate)
        return getErr(candidate[1], candidate[2])
    end
    local search = SimplexSearch:new(vertices, objective)

    -- Fixed budget of 64 iterations; the values from the last one are kept.
    -- NOTE(review): iterate() presumably returns the current best vertex and
    -- its objective value -- confirm against SimplexSearch.
    local r2
    for _ = 1, 64 do
        local point, residual = search:iterate()
        r = point[1]
        x0 = point[2]
        r2 = 1 - residual / sstot
    end

    -- Height for the final (r, x0), then undo the x rescaling.
    local hh, hy = dots(r, x0)
    self._coeff2 = hy / hh
    self._coeff1 = r * xScale
    self._coeff0 = (x0 - xOffset) / xScale
    self._r2 = r2

    return true
end

--- Builds a textual form of the fitted curve.
-- Gives 'y = 0' when the height coefficient is exactly zero; otherwise
-- 'y = k / (1 + e^{-r(x - x0)})' with each number formatted using %g.
-- @return the equation string
function SigmoidRegression:getEquation()
    local prefix = 'y = '
    if self._coeff2 == 0 then
        return prefix .. '0'
    end
    local height = string.format('%g', self._coeff2)
    local denominator = string.format('(1 + e^{-%g(x - %g)})', self._coeff1, self._coeff0)
    return prefix .. height .. ' / ' .. denominator
end

--- Evaluates the fitted curve at x.
-- @param x point at which to evaluate
-- @return self._coeff2 times the unit sigmoid with midpoint self._coeff0
--         and rate self._coeff1
function SigmoidRegression:getValue(x)
    local height = self._coeff2
    local unit = _h(x, self._coeff0, self._coeff1)
    return height * unit
end

-- Module value handed back to require(): the SigmoidRegression class table.
return SigmoidRegression
